﻿using Microsoft.ProgramSynthesis.NLToCode;
using Microsoft.ProgramSynthesis.Utils;
using Microsoft.ProgramSynthesis.Wrangling.Schema.TableOutput;
using Microsoft.ProgramSynthesis.NLToM.Tests;
using Newtonsoft.Json;
using Microsoft.SqlServer.Server;
using Microsoft.ProgramSynthesis.Specifications;
using System.Data;
using System.Text.RegularExpressions;
using Microsoft.ProgramSynthesis.NLToM;
using Microsoft.Mashup.DocumentServices;
using Microsoft.ProgramSynthesis.Detection.FileType;
using System.Runtime.InteropServices;

namespace NLToCode.Evaluation {
    /// <summary>
    /// Helpers for the NL-to-code evaluation harness: running ground-truth and predicted programs,
    /// converting between <c>ITable&lt;object&gt;</c> and <c>TableData</c>, computing exact/semantic
    /// match metrics, and writing benchmarks to disk.
    /// </summary>
    public class Utils {
        // Captures the text between the first '(' and the next ',' (typically the table argument of
        // the first call, e.g. "Source" in "Table.SelectRows(Source, ...)").
        private static readonly Regex FirstCallArgumentRegex = new(@"\(([^,]+),", RegexOptions.Compiled);

        /// <summary>
        /// Executes each program against <paramref name="inputTable"/> and keeps the programs that
        /// produced a non-null output table, preserving input order.
        /// </summary>
        /// <returns>(program, output table) pairs; programs whose execution returned null are omitted.</returns>
        private static async Task<List<Tuple<string, ITable<object>>>> ComputeProgramOutputsAsync(
            IProgramExecutor programExecutor,
            List<string> programs,
            ITable<object> inputTable,
            string inputTableName,
            string outputTableName) {
            List<Tuple<string, ITable<object>>> programToProgOutputMap = new();

            foreach (string program in programs) {
                ITable<object> outputTable = await programExecutor.ExecuteProgramOnTableAsync(program,
                    inputTable,
                    inputTableName,
                    outputTableName);
                if (outputTable != null) {
                    programToProgOutputMap.Add(Tuple.Create(program, outputTable));
                }
            }

            return programToProgOutputMap;
        }

        /// <summary>
        /// Rewrites <paramref name="predCode"/> so that it uses the same table name as
        /// <paramref name="groundTruth"/>. In each program, the "table name" is the text between the
        /// first '(' and the next ','. Every occurrence of the predicted name in
        /// <paramref name="predCode"/> is replaced.
        /// </summary>
        /// <returns>
        /// The rewritten prediction. <paramref name="predCode"/> is returned unchanged when either
        /// input is null or either program lacks such an argument.
        /// </returns>
        public static string GetCorrectTableName(string groundTruth, string predCode) {
            if (groundTruth == null || predCode == null) { return predCode; }

            var truthMatch = FirstCallArgumentRegex.Match(groundTruth);
            if (!truthMatch.Success) { return predCode; }

            var predMatch = FirstCallArgumentRegex.Match(predCode);
            if (!predMatch.Success) { return predCode; }

            // NOTE(review): string.Replace rewrites every occurrence, including substrings of longer identifiers.
            return predCode.Replace(predMatch.Groups[1].Value, truthMatch.Groups[1].Value);
        }

        /// <summary>
        /// Executes the ground-truth answer and every predicted program of each benchmark against the
        /// benchmark's input table, and stores the serialized outputs on the benchmark.
        /// </summary>
        /// <param name="benchmarks">Benchmarks to execute; mutated in place.</param>
        /// <param name="configuration">Selects the target language; only the M path has an executor.</param>
        /// <returns>
        /// The same <paramref name="benchmarks"/> list, with <c>GroundTruthOutput</c> and (when
        /// predictions exist) <c>PredictedOutputs</c> populated.
        /// </returns>
        /// <exception cref="NotSupportedException">The configured language is Pandas, which has no executor wired up.</exception>
        public static async Task<List<Benchmark>> ExecuteBenchmarks(List<Benchmark> benchmarks, Configuration configuration) {
            foreach (Benchmark benchmark in benchmarks) {
                Console.WriteLine(benchmark.Id);
                ITable<object> table = GetTable(benchmark);
                string groundTruth = benchmark.Answer;

                if (configuration.Language == Configuration.LanguageMode.Pandas) {
                    // Without an executor the session would be null and crash below with a
                    // NullReferenceException; fail with a clear message instead.
                    // TODO: plug in a Python program executor (e.g. a PythonRpcServer) to support Pandas.
                    throw new NotSupportedException("Executing Pandas benchmarks is not supported: no Python program executor is configured.");
                }

                Session session = new MSession {
                    ProgramExecutor = new MExecutionEngine(),
                    PostprocessingMode = PostprocessingConfig.ExecutionBasedWithInterleaving,
                    InputTable = table,
                    InputTableName = benchmark.InputTable,
                };

                var groundTruthOutput = await session.ProgramExecutor.ExecuteProgramOnTableAsync(
                    groundTruth,
                    table,
                    benchmark.InputTable);
                benchmark.GroundTruthOutput = Tuple.Create(groundTruth, ConvertToTableData(groundTruthOutput));

                if (benchmark.PredictedCode != null) {
                    // The output table name is not tracked yet, so an empty name is passed.
                    var outputs = await ComputeProgramOutputsAsync(session.ProgramExecutor, benchmark.PredictedCode, table, benchmark.InputTable, "");
                    benchmark.PredictedOutputs = outputs
                        .Select(o => Tuple.Create(o.Item1, ConvertToTableData(o.Item2)))
                        .ToList();
                }
            }
            return benchmarks;
        }

        /// <summary>
        /// Trims the separator characters , . { } ( ) from both ends of <paramref name="s"/>, then
        /// splits on them. Adjacent separators inside the string yield empty tokens.
        /// </summary>
        public static List<string> Tokenize(string s) {
            char[] separators = { ',', '.', '{', '}', '(', ')' };
            return s.Trim(separators).Split(separators).ToList();
        }

        /// <summary>Parses the benchmark's input dataset into a table.</summary>
        public static ITable<object> GetTable(Benchmark benchmark) {
            return Microsoft.ProgramSynthesis.NLToCode.Tests.Utils.ParseInputTable(benchmark.Dataset.Values, benchmark.Dataset.ColumnTypes.Split(',').ToList(), benchmark.Dataset.Header);
        }

        /// <summary>
        /// Parses serialized table data back into a table. This is best effort: it returns null when
        /// <paramref name="table"/> is null, lacks column types or values, or cannot be parsed.
        /// </summary>
        public static ITable<object> GetTable(TableData table) {
            if (table?.ColumnTypes == null || table.Values == null) { return null; }
            try {
                return Microsoft.ProgramSynthesis.NLToCode.Tests.Utils.ParseInputTable(table.Values, table.ColumnTypes.Split(',').ToList(), table.Header);
            }
            catch (Exception ex) {
                // An unparsable output is treated as "no table" (it can never match), but report why.
                Console.WriteLine(ex.Message);
                return null;
            }
        }

        /// <summary>
        /// Computes the exact-match and semantic-match ranks for every benchmark and serializes the
        /// aggregate as indented JSON. It also sets <c>ResultFoundTop1</c> on each benchmark.
        /// </summary>
        /// <param name="benchmarks">Benchmarks whose ground-truth and predicted outputs have been populated.</param>
        /// <param name="configuration">Currently unused.</param>
        /// <returns>Indented JSON of an <c>ExactAndSemanticMatchResults</c>.</returns>
        public static string PlotMetrics(List<Benchmark> benchmarks, Configuration configuration) {
            var exactMatch = new List<int>();
            var semanticMatch = new List<int>();
            // Largest candidate count across benchmarks. DefaultIfEmpty keeps Max() from throwing on an empty list.
            int k = benchmarks.Select(b => b.PredictedCode?.Count ?? 0).DefaultIfEmpty(0).Max();

            foreach (Benchmark benchmark in benchmarks) {
                Console.WriteLine(benchmark.Id);

                Tuple<int, int> ranks = FindExactAndSemanticMatch(benchmark);
                int exactRank = ranks.Item1;
                int semanticRank = ranks.Item2;

                if (exactRank != -1) { exactMatch.Add(exactRank); }
                if (semanticRank != -1) { semanticMatch.Add(semanticRank); }
                benchmark.ResultFoundTop1 = semanticRank == 1;
            }

            var results = new ExactAndSemanticMatchResults(exactMatch, semanticMatch, k, benchmarks.Count);
            return JsonConvert.SerializeObject(results, Formatting.Indented);
        }

        /// <summary>
        /// Finds two 1-based ranks within <c>PredictedCode</c>:
        /// <list type="bullet">
        /// <item>the first candidate that equals the ground truth, ignoring whitespace;</item>
        /// <item>the first candidate whose output table is equivalent to the ground-truth output.</item>
        /// </list>
        /// When <c>MatchColumns</c> is not "FALSE", the semantic match also requires identical column names.
        /// </summary>
        /// <returns>(exact rank, semantic rank); each is -1 when no candidate matches.</returns>
        public static Tuple<int, int> FindExactAndSemanticMatch(Benchmark benchmark) {
            static string RemoveWhitespace(string s) => string.Concat(s.Where(c => !char.IsWhiteSpace(c)));

            List<string> predictedCode = benchmark.PredictedCode ?? new List<string>();
            bool matchColumn = benchmark.MatchColumns != "FALSE";

            // Exact match. Null candidates are skipped both when testing and when locating the rank.
            string normalizedGroundTruth = benchmark.Answer == null ? null : RemoveWhitespace(benchmark.Answer);
            int exactPosition = normalizedGroundTruth == null
                ? -1
                : predictedCode.FindIndex(p => p != null && RemoveWhitespace(p) == normalizedGroundTruth);
            int exactMatchIndex = exactPosition >= 0 ? exactPosition + 1 : -1;

            // Semantic match: the first predicted output that is equivalent to the ground-truth output.
            int semanticMatchIndex = -1;
            ITable<object> groundTruthTable = GetTable(benchmark.GroundTruthOutput?.Item2);
            if (groundTruthTable != null && benchmark.PredictedOutputs != null) {
                int outputPosition = 0;
                foreach (var predicted in benchmark.PredictedOutputs) {
                    outputPosition++;
                    ITable<object> predictedTable = GetTable(predicted.Item2);
                    bool isMatch = predictedTable != null
                        && Session.TableEquivalence.Equals(predictedTable, groundTruthTable)
                        && (!matchColumn || groundTruthTable.ColumnNames.SequenceEqual(predictedTable.ColumnNames));
                    if (isMatch) {
                        // PredictedOutputs omits programs that produced no table, so a position there
                        // is not the candidate's rank. Rank by position in PredictedCode (as the exact
                        // match does), and fall back to the output position if the program is not found.
                        int candidatePosition = predictedCode.IndexOf(predicted.Item1);
                        semanticMatchIndex = candidatePosition >= 0 ? candidatePosition + 1 : outputPosition;
                        break;
                    }
                }
            }

            return Tuple.Create(exactMatchIndex, semanticMatchIndex);
        }

        /// <summary>
        /// Serializes a table into <c>TableData</c>. The result holds:
        /// <list type="bullet">
        /// <item>the header;</item>
        /// <item>every cell as a string, with null written as "";</item>
        /// <item>a comma-separated list of column types inferred from the first row.</item>
        /// </list>
        /// </summary>
        /// <returns>An empty <c>TableData</c> for a null table, and a header-only one for a table with no rows.</returns>
        public static TableData ConvertToTableData(ITable<object> table) {
            var tableData = new TableData();
            if (table == null) { return tableData; }
            tableData.Header = table.ColumnNames.ToList();
            // Any() inspects only the first row, whereas Count() would enumerate the whole table.
            if (!table.Any()) { return tableData; }

            List<string> colTypes = table.First().Select(cell => InferColumnType(cell)).ToList();
            tableData.Values = table.Select(row => row.Select(cell => cell?.ToString() ?? "").ToList()).ToList();
            tableData.ColumnTypes = string.Join(",", colTypes);
            return tableData;
        }

        /// <summary>
        /// Maps a cell's runtime type name to the type names written into <c>TableData.ColumnTypes</c>.
        /// Null cells and unrecognized types map to "string".
        /// </summary>
        private static string InferColumnType(object cell) {
            string typeName = cell?.GetType().Name ?? "string";
            if (typeName.Contains("Int", StringComparison.Ordinal)) { return "int"; }
            if (typeName.Contains("Double", StringComparison.Ordinal)) { return "double"; }
            if (typeName.Contains("DateTime", StringComparison.Ordinal)) { return "DateTime"; }
            return "string";
        }

        /// <summary>Writes <paramref name="benchmark"/> as JSON to <c>{outputPath}/{dataset}_{Id}.json</c>, overwriting any existing file.</summary>
        public static void WriteBenchmark(string dataset, Benchmark benchmark, string outputPath) {
            string jsonData = JsonConvert.SerializeObject(benchmark);
            string filePath = Path.Combine(outputPath, $"{dataset}_{benchmark.Id}.json");
            File.WriteAllText(filePath, jsonData);
        }

        /// <summary>A named step of a parsed query: the step name and its script text.</summary>
        public class Step {
            public string Name { get; set; }
            public string Code { get; set; }
            public Step(string name, string code) {
                Name = name;
                Code = code;
            }
        }

        /// <summary>
        /// Parses a query script via <c>QueryStateUtils.ParseQueryScript</c> and returns its steps.
        /// Leading and trailing ';' characters are trimmed first.
        /// </summary>
        /// <returns>The query's steps, or null when parsing fails (the error message is written to the console).</returns>
        public static List<Step> UnrollQuery(string query) {
            try {
                QueryState queryState = QueryStateUtils.ParseQueryScript("", query.Trim(';'));
                return queryState.Steps.Select(step => new Step(step.NameIdentifier.Name.ToString(), step.MashupScript)).ToList();
            }
            catch (Exception ex) {
                // Best effort: an unparsable query is reported and yields null.
                Console.WriteLine(ex.Message);
                return null;
            }
        }

    }
}
